在前幾天的內容中,我們談到了AI模型的運作與更新方式,也介紹了Pytorch這項好用的工具。在昨天更是看到了AI形模型是如何模擬人腦的運作。今明兩天,我們將利用pytorch展示如何從頭開始建立自己的AI模型。
在Pytorch中有三種常見的方式可以獲取我們所需要的資料集,下面會根據使用難度,由易到難依序展開介紹:
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
print(train_dataset.__getitem__(10))
print(train_dataset.__len__)
transform
,這個動作的目的在於對所有資料集中的資料「統一進行某種操作」,操作可以是型態轉換(tensor跟numpy資料型態互換)、基礎影像處理(crop, normalize等操作)或是自定義的其餘操作等等,簡單來說就是一個可以快速調整資料的方式,這個部分更多的內容會在下面的補充章節中討論。.__lne__()
,一個是.__getitem__(index)
,前者可以告訴我們這個資料集的大小,後者則是在給定特定index的情況下,告訴你這筆資料裡面包含了甚麼,通常是(影像,標籤)這樣的格式。除了直接使用Pytorch中提供的影像資料以外,有時候,我們想要在自己的影像資料上訓練自己的模型,這個時候我們可以下面兩種方式:
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
__len__()
與__getitem__()
這兩個函式,最上面的__init__
函式是輔助作用,在我們將一個這樣的資料集實例化之後,第一個會執行的函式,所以可以在這邊定義一些後續會需要用到的東西。.__lne__()
與.__getitem__(index)
這兩個函式,必須要回傳資料集的大小以及特定index的那筆資料所包含的東西。以上面的程式碼為例,我們將所有影像的資訊透過csv檔讀進來,並且保存在self.img_labels
中,接著在def __getitem__(self, idx)
裡面,我們剛剛存下來的影像資訊,把每張影像的「路徑」與「標註(Label)」都取出來,分別用image, label
表示。torch.utils.data
中的DataLoader
來幫助我們將已經建立好的dataset包裝成一批一批的資料,就像Pytorch中提供的範例:from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
.__getitem__(index)
函式來回傳資料),後者則是一批一批的資料。